import gurobipy as gp
from gurobipy import GRB

class SeparationProblem:
    """Build and solve a Gurobi MIP that searches for a new index set ``C``.

    The model chooses binary ``z[i]`` (index ``i`` belongs to ``C``) and
    binary ``y[i][j]`` (entry ``(i, j)`` of ``sol_X`` is counted). It
    maximizes ``sum(y[i][j] * sol_X[i][j]) - sum(z[i])`` subject to the
    constraints built in ``define_separation_problem``.

    Args:
        sol_X: Per-index sequences of weights; ``sol_X[i][j]`` is the
            objective coefficient of ``y[i][j]``. Presumably the current
            master/LP solution -- TODO confirm with caller.
        K: Per-index 2-D arrays. ``len(K[i])`` is the number of ``y``
            variables for index ``i``, and row ``K[i][j, :]`` is read at
            positions ``0..n-1`` (assumes NumPy-style 2-D indexing -- TODO
            confirm).
        C_set: List of previously found index sets. ``solve`` appends newly
            found sets to this list IN PLACE and also returns it.
    """

    def __init__(self, sol_X, K, C_set):
        self.sol_X = sol_X
        self.K = K
        self.C_set = C_set

    def define_separation_problem(self):
        """Build the separation MIP and return ``(model, z)``.

        The model and its variables are also stored on ``self.model``,
        ``self.y`` (one ``tupledict`` per index ``i``) and ``self.z``.

        For every ``i`` and every row ``j`` of ``K[i]``, the constraints are:
          * ``y[i][j] <= z[i]``
          * ``y[i][j] <= sum_ii K[i][j, ii] * z[ii]``
        plus one global constraint ``sum_i z[i] >= 2``.
        """
        n = len(self.sol_X)
        self.model = gp.Model('Separation Problem')
        self.model.Params.OutputFlag = 0
        self.model.Params.Threads = 1
        # self.model.setParam("LogFile", "./logs/gurobi.log")
        self.model.update()
        self.y = [
            self.model.addVars(len(self.K[i]), lb=0, ub=1, vtype=GRB.BINARY, name=f'y{i}')
            for i in range(n)
        ]
        self.z = self.model.addVars(n, lb=0, ub=1, vtype=GRB.BINARY, name='Z')

        # Objective: reward selected sol_X entries, pay 1 per selected index.
        obj = gp.quicksum(
            self.y[i][j] * self.sol_X[i][j]
            for i in range(n)
            for j in range(len(self.K[i]))
        ) - gp.quicksum(self.z[i] for i in range(n))
        self.model.setObjective(obj, GRB.MAXIMIZE)

        for i in range(n):
            for j in range(len(self.K[i])):
                # y[i][j] may only be counted if index i itself is selected...
                self.model.addConstr(self.y[i][j] - self.z[i] <= 0)
                # ...and if at least one index in row j of K[i] is selected.
                pattern = self.K[i][j, :]
                s = gp.quicksum(self.z[ii] * pattern[ii] for ii in range(n))
                self.model.addConstr(self.y[i][j] - s <= 0)

        # The set C must contain at least two indices.
        self.model.addConstr(gp.quicksum(self.z[i] for i in range(n)) >= 2)

        return self.model, self.z

    def solve(self):
        """Solve the separation problem and record the resulting index set.

        Returns:
            ``(have_cycle, C_set)``. ``have_cycle`` is True when the optimal
            set ``C`` was not already in ``C_set``; in that case ``C`` has
            been appended to ``C_set`` (the same list passed to ``__init__``).
            The flag is also stored on ``self.have_cycle``.

        Raises:
            RuntimeError: if Gurobi does not report an OPTIMAL status.
        """
        n = len(self.sol_X)
        separation_problem, z_var = self.define_separation_problem()
        separation_problem.optimize()
        if separation_problem.status != GRB.OPTIMAL:
            raise RuntimeError("Optimization failed")

        # Binary variables are integral only up to Gurobi's IntFeasTol, so a
        # selected index can come back as e.g. 0.9999999. An exact `== 1`
        # test would silently drop it, so threshold at 0.5 instead.
        C = {i for i in range(n) if z_var[i].X > 0.5}
        if C in self.C_set:
            self.have_cycle = False
        else:
            self.C_set.append(C)
            self.have_cycle = True
        return self.have_cycle, self.C_set
        